{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 微调chatglm\n",
    "\n",
    "\n",
    "1. 基于 data 目录下的数据训练 ChatGLM3 模型，使用 inference Notebook 对比微调前后的效果。\n",
    "2. （可选）：将 gen_dataset Notebook 改写为 py 文件。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sat Feb 17 18:12:24 2024       \n",
      "+---------------------------------------------------------------------------------------+\n",
      "| NVIDIA-SMI 535.146.02             Driver Version: 535.146.02   CUDA Version: 12.2     |\n",
      "|-----------------------------------------+----------------------+----------------------+\n",
      "| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
      "| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |\n",
      "|                                         |                      |               MIG M. |\n",
      "|=========================================+======================+======================|\n",
      "|   0  NVIDIA GeForce RTX 3090        On  | 00000000:C1:00.0 Off |                  N/A |\n",
      "| 32%   40C    P3             125W / 350W |      0MiB / 24576MiB |      0%      Default |\n",
      "|                                         |                      |                  N/A |\n",
      "+-----------------------------------------+----------------------+----------------------+\n",
      "                                                                                         \n",
      "+---------------------------------------------------------------------------------------+\n",
      "| Processes:                                                                            |\n",
      "|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |\n",
      "|        ID   ID                                                             Usage      |\n",
      "|=======================================================================================|\n",
      "|  No running processes found                                                           |\n",
      "+---------------------------------------------------------------------------------------+\n"
     ]
    }
   ],
   "source": [
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'/root'"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pwd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "zhouyi_dataset_20240118_163659.csv\n"
     ]
    }
   ],
   "source": [
    "ls data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. 参数定义"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "# 定义全局变量和参数\n",
    "model_name_or_path = '/root/autodl-tmp/models/chatglm3-6b'  # 模型ID或本地路径\n",
    "# train_data_path = 'data/zhouyi_dataset_handmade.csv'    # 训练数据路径\n",
    "train_data_path = 'data/zhouyi_dataset_20240118_163659.csv'    # 训练数据路径(批量生成数据集）\n",
    "eval_data_path = None                     # 验证数据路径，如果没有则设置为None\n",
    "seed = 8                                 # 随机种子\n",
    "max_input_length = 512                    # 输入的最大长度\n",
    "max_output_length = 1536                  # 输出的最大长度\n",
    "lora_rank = 16                             # LoRA秩\n",
    "lora_alpha = 32                           # LoRA alpha值\n",
    "lora_dropout = 0.05                       # LoRA Dropout率\n",
    "prompt_text = ''                          # 所有数据前的指令文本"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. 通用函数定义"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import ClassLabel, Sequence\n",
    "import random\n",
    "import pandas as pd\n",
    "from IPython.display import display, HTML\n",
    "\n",
    "def show_random_elements(dataset, num_examples=10):\n",
    "    assert num_examples <= len(dataset), \"Can't pick more elements than there are in the dataset.\"\n",
    "    picks = []\n",
    "    for _ in range(num_examples):\n",
    "        pick = random.randint(0, len(dataset)-1)\n",
    "        while pick in picks:\n",
    "            pick = random.randint(0, len(dataset)-1)\n",
    "        picks.append(pick)\n",
    "    \n",
    "    df = pd.DataFrame(dataset[picks])\n",
    "    for column, typ in dataset.features.items():\n",
    "        if isinstance(typ, ClassLabel):\n",
    "            df[column] = df[column].transform(lambda i: typ.names[i])\n",
    "        elif isinstance(typ, Sequence) and isinstance(typ.feature, ClassLabel):\n",
    "            df[column] = df[column].transform(lambda x: [typ.feature.names[i] for i in x])\n",
    "    display(HTML(df.to_html()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# tokenize_func 函数\n",
    "def tokenize_func(example, tokenizer, ignore_label_id=-100):\n",
    "    \"\"\"\n",
    "    对单个数据样本进行tokenize处理。\n",
    "\n",
    "    参数:\n",
    "    example (dict): 包含'content'和'summary'键的字典，代表训练数据的一个样本。\n",
    "    tokenizer (transformers.PreTrainedTokenizer): 用于tokenize文本的tokenizer。\n",
    "    ignore_label_id (int, optional): 在label中用于填充的忽略ID，默认为-100。\n",
    "\n",
    "    返回:\n",
    "    dict: 包含'tokenized_input_ids'和'labels'的字典，用于模型训练。\n",
    "    \"\"\"\n",
    "\n",
    "    # 构建问题文本\n",
    "    question = prompt_text + example['content']\n",
    "    if example.get('input', None) and example['input'].strip():\n",
    "        question += f'\\n{example[\"input\"]}'\n",
    "\n",
    "    # 构建答案文本\n",
    "    answer = example['summary']\n",
    "\n",
    "    # 对问题和答案文本进行tokenize处理\n",
    "    q_ids = tokenizer.encode(text=question, add_special_tokens=False)\n",
    "    a_ids = tokenizer.encode(text=answer, add_special_tokens=False)\n",
    "\n",
    "    # 如果tokenize后的长度超过最大长度限制，则进行截断\n",
    "    if len(q_ids) > max_input_length - 2:  # 保留空间给gmask和bos标记\n",
    "        q_ids = q_ids[:max_input_length - 2]\n",
    "    if len(a_ids) > max_output_length - 1:  # 保留空间给eos标记\n",
    "        a_ids = a_ids[:max_output_length - 1]\n",
    "\n",
    "    # 构建模型的输入格式\n",
    "    input_ids = tokenizer.build_inputs_with_special_tokens(q_ids, a_ids)\n",
    "    question_length = len(q_ids) + 2  # 加上gmask和bos标记\n",
    "\n",
    "    # 构建标签，对于问题部分的输入使用ignore_label_id进行填充\n",
    "    labels = [ignore_label_id] * question_length + input_ids[question_length:]\n",
    "\n",
    "    return {'input_ids': input_ids, 'labels': labels}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from typing import List, Dict, Optional\n",
    "\n",
    "# DataCollatorForChatGLM 类\n",
    "class DataCollatorForChatGLM:\n",
    "    \"\"\"\n",
    "    用于处理批量数据的DataCollator，尤其是在使用 ChatGLM 模型时。\n",
    "\n",
    "    该类负责将多个数据样本（tokenized input）合并为一个批量，并在必要时进行填充(padding)。\n",
    "\n",
    "    属性:\n",
    "    pad_token_id (int): 用于填充(padding)的token ID。\n",
    "    max_length (int): 单个批量数据的最大长度限制。\n",
    "    ignore_label_id (int): 在标签中用于填充的ID。\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, pad_token_id: int, max_length: int = 2048, ignore_label_id: int = -100):\n",
    "        \"\"\"\n",
    "        初始化DataCollator。\n",
    "\n",
    "        参数:\n",
    "        pad_token_id (int): 用于填充(padding)的token ID。\n",
    "        max_length (int): 单个批量数据的最大长度限制。\n",
    "        ignore_label_id (int): 在标签中用于填充的ID，默认为-100。\n",
    "        \"\"\"\n",
    "        self.pad_token_id = pad_token_id\n",
    "        self.ignore_label_id = ignore_label_id\n",
    "        self.max_length = max_length\n",
    "\n",
    "    def __call__(self, batch_data: List[Dict[str, List]]) -> Dict[str, torch.Tensor]:\n",
    "        \"\"\"\n",
    "        处理批量数据。\n",
    "\n",
    "        参数:\n",
    "        batch_data (List[Dict[str, List]]): 包含多个样本的字典列表。\n",
    "\n",
    "        返回:\n",
    "        Dict[str, torch.Tensor]: 包含处理后的批量数据的字典。\n",
    "        \"\"\"\n",
    "        # 计算批量中每个样本的长度\n",
    "        len_list = [len(d['input_ids']) for d in batch_data]\n",
    "        batch_max_len = max(len_list)  # 找到最长的样本长度\n",
    "\n",
    "        input_ids, labels = [], []\n",
    "        for len_of_d, d in sorted(zip(len_list, batch_data), key=lambda x: -x[0]):\n",
    "            pad_len = batch_max_len - len_of_d  # 计算需要填充的长度\n",
    "            # 添加填充，并确保数据长度不超过最大长度限制\n",
    "            ids = d['input_ids'] + [self.pad_token_id] * pad_len\n",
    "            label = d['labels'] + [self.ignore_label_id] * pad_len\n",
    "            if batch_max_len > self.max_length:\n",
    "                ids = ids[:self.max_length]\n",
    "                label = label[:self.max_length]\n",
    "            input_ids.append(torch.LongTensor(ids))\n",
    "            labels.append(torch.LongTensor(label))\n",
    "\n",
    "        # 将处理后的数据堆叠成一个tensor\n",
    "        input_ids = torch.stack(input_ids)\n",
    "        labels = torch.stack(labels)\n",
    "\n",
    "        return {'input_ids': input_ids, 'labels': labels}"
   ]
  },
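  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sanity check of `DataCollatorForChatGLM` on hand-made token IDs (a sketch: the IDs and `pad_token_id=0` below are placeholders, not real ChatGLM3 vocabulary). The shorter sample should be right-padded to the batch maximum, and the padded label positions masked with `-100`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Demonstrate padding/masking with made-up token IDs (pad_token_id=0 is a placeholder)\n",
    "demo_collator = DataCollatorForChatGLM(pad_token_id=0)\n",
    "demo_batch = demo_collator([\n",
    "    {'input_ids': [5, 6, 7, 8], 'labels': [-100, -100, 7, 8]},\n",
    "    {'input_ids': [5, 6], 'labels': [-100, 6]},\n",
    "])\n",
    "print(demo_batch['input_ids'])\n",
    "print(demo_batch['labels'])"
   ]
  },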
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3.数据处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>content</th>\n",
       "      <th>summary</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>需卦在周易中怎样表达教育的概念？</td>\n",
       "      <td>需卦是一个占卜卦象，在周易卦象中，它由下卦乾和上卦坎组成。坎象征着云，乾象征着天，云聚于天，形成了等待时机的卦象。这是一个大吉大利的卜问，特别适合涉水渡河。根据《象辞》，君子观此卦象，可以宴饮安乐，待时而动。\\n\\n需卦的核心哲学是：稳扎稳打，不可冒失行动，观时待变，耐心等待。在事业上，需要审时度势，守中正，不可急进，自信充满，可化险为夷。在经商中，必须充满耐心，创造条件和机会，行事光明磊落，等到时机成熟后必然一帆风顺。在感情方面，亦需慎重，培养感情，以诚实、热情相待，时机成熟后可以有良好的结果。\\n\\n需卦紧随坤宫游魂卦，预示着踌躇和期待。得此卦者，需要时机尚未成熟，需要耐心等待，急进反会见凶险。</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>周易中的比卦含义是什么？</td>\n",
       "      <td>在周易中，比卦代表着相亲相依的意涵，它是由上卦坎（水）和下卦坤（地）相叠而成。这一卦象预示着长期的吉利和无咎，同时也暗示着不愿臣服的邦国迟迟不来朝会有难。\\n\\n比卦的核心哲学是：水附大地，地纳河海，象征相亲相依，亲密无间，展示出宽宏无私，精诚团结的道理。因此，在这一卦象中，人们可以得到贵人的提拔，事业可望成功，但需诚实、信任地做事，待人宽厚、正直，主动热情。在经商中，也需真诚交往，遵守商业道德，不可贪心不足或自以为是。\\n\\n比卦的运势平顺，事业顺利可望成功，可得贵人提拔。经商方面，愿望能够实现且有利润，但需与他人密切合作，讲究商业道德。在婚恋方面，象征着美好姻缘和相亲相爱。在决策中，建议心地善良，待人忠诚、厚道，工作勤恳并善于选择朋友。</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>请描述蒙卦的含义。</td>\n",
       "      <td>蒙卦是由艮卦（山）下，坎卦（水）上组成的异卦相叠。它代表着通泰，启蒙的意义。在这里，卜者并非是在向幼稚愚昧的人取求，而是幼稚愚昧的人在向卜者求教。第一次卜筮就得到了神灵的指示。然而，如果轻慢不敬地再三卜筮的话，神灵便不会再示警。总的来说，这是一个吉利的卜问。\\n\\n蒙卦的核心在于山下有泉的形象，寓意着启蒙。君子观此卦象，应当以果敢坚毅的行动来培养自身的品德，像山泉一样果断行动。然而，此卦乃是离宫四世卦，它代表着回还往复、疑惑不前、多忧愁过失，因而属于凶卦。\\n\\n蒙卦在个人发展、事业经商、求名婚恋等方面的解释不一。在事业方面，表示事业初建，具有启蒙和通达之象，需要勇敢坚毅的行动；而在经商方面，需要务必小心谨慎，树立高尚的商业道德，不可急功近利；求名方面，需要接受良好的基础教育，陶冶情操。整体而言，此卦提示须忍耐待机而动，听取别人意见，方能通达运势。</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>周易的乾卦讲述了什么？</td>\n",
       "      <td>在周易中，乾卦是六十四卦之首，由六个阳爻组成，象征着天。它所代表的是刚健、健行、刚健不屈的意境。乾卦的核心哲学是：天道刚健，运行不已，君子观此卦象，从而以天为法，自强不息。\\n\\n乾卦象征天，为大通而至正。得此卦者，名利双收，应把握机会，争取成果。然而，切勿过于骄傲自满，而应保持谦逊、冷静和警惕。在事业、经商、求名等方面，乾卦皆暗示着大吉大利，但也警示着必须坚持正道、修养德行，方能永远亨通。\\n\\n在婚恋方面，乾卦提示着阳盛阴衰，但也强调刚柔相济，相互补足，形成美满的结果。在决策方面，则是强调刚健、正直、公允，自强不息的实质，需要修养德行、坚定信念，方能克服困难，消除灾难。</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>周易中的讼卦含义是什么？</td>\n",
       "      <td>在周易中，讼卦是一个充满警示的卦象。它由上卦乾（天）和下卦坎（水）组成，代表着天与水背道而驰，形成争讼的局面。虽然事情开始时有利可图，但必须警惕戒惧，因为中间虽然吉利，但最终会带来凶险。对于涉及大川，涉水渡河的行动不利。因此，君子观此卦象，应当慎之又慎，杜绝争讼之事，并在谋事之初谨慎行事。讼卦的核心哲学是要避免争讼，退而让人，求得化解，安于正理，方可避免意外之灾。在事业上，务必避免介入诉讼纠纷的争执之中，与其这样，不如退而让人。即使最终获胜，也难免得失不均。经商方面，要坚持公正、公平、互利的原则，避免冲突，这样会有好结果。而对于求名、婚恋和决策，也都需要慎重行事，避免盲目追求，退让让人，可助事业、婚姻和决策的发展。</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "dataset = load_dataset(\"csv\", data_files=train_data_path)\n",
    "show_random_elements(dataset[\"train\"], num_examples=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)"
   ]
  },
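  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick look at the special token IDs that `tokenize_func` and the data collator rely on (a sketch using the standard Hugging Face tokenizer attributes; the actual values depend on the chatglm3-6b checkpoint that was loaded)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the tokenizer's special token IDs used for padding and end-of-sequence\n",
    "print('pad_token_id:', tokenizer.pad_token_id)\n",
    "print('eos_token_id:', tokenizer.eos_token_id)"
   ]
  },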
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "column_names = dataset['train'].column_names\n",
    "tokenized_dataset = dataset['train'].map(\n",
    "    lambda example: tokenize_func(example, tokenizer),\n",
    "    batched=False, \n",
    "    remove_columns=column_names\n",
    ")"
   ]
  },
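  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A sanity check on one tokenized sample (a sketch, assuming the first row of the dataset): `input_ids` and `labels` must have the same length, and the question prefix (plus the two special prefix tokens) should be masked with `-100` so it does not contribute to the loss."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Verify that input_ids and labels line up and that the prompt part is masked\n",
    "sample = tokenized_dataset[0]\n",
    "print('sequence length :', len(sample['input_ids']))\n",
    "print('labels length   :', len(sample['labels']))\n",
    "print('masked positions:', sum(1 for t in sample['labels'] if t == -100))"
   ]
  },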
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenized_dataset = tokenized_dataset.shuffle(seed=seed)\n",
    "tokenized_dataset = tokenized_dataset.flatten_indices()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 准备数据整理器\n",
    "data_collator = DataCollatorForChatGLM(pad_token_id=tokenizer.pad_token_id)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. 加载模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c60efcf6879e4b5187da53f4cfd57712",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "\n",
    "from transformers import AutoModel, BitsAndBytesConfig\n",
    "\n",
    "_compute_dtype_map = {\n",
    "    'fp32': torch.float32,\n",
    "    'fp16': torch.float16,\n",
    "    'bf16': torch.bfloat16\n",
    "}\n",
    "\n",
    "# QLoRA 量化配置\n",
    "q_config = BitsAndBytesConfig(load_in_4bit=True,\n",
    "                              bnb_4bit_quant_type='nf4',\n",
    "                              bnb_4bit_use_double_quant=True,\n",
    "                              bnb_4bit_compute_dtype=_compute_dtype_map['bf16'])\n",
    "# 加载量化后模型\n",
    "model = AutoModel.from_pretrained(model_name_or_path,\n",
    "                                  quantization_config=q_config,\n",
    "                                  device_map='auto',\n",
    "                                  trust_remote_code=True)\n",
    "\n",
    "model.supports_gradient_checkpointing = True  \n",
    "model.gradient_checkpointing_enable()\n",
    "model.enable_input_require_grads()\n",
    "\n",
    "model.config.use_cache = False  # silence the warnings. Please re-enable for inference!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3739.69MiB\n"
     ]
    }
   ],
   "source": [
    "# 获取当前模型占用的 GPU显存（差值为预留给 PyTorch 的显存）\n",
    "memory_footprint_bytes = model.get_memory_footprint()\n",
    "memory_footprint_mib = memory_footprint_bytes / (1024 ** 2)  # 转换为 MiB\n",
    "\n",
    "print(f\"{memory_footprint_mib:.2f}MiB\")"
   ]
  },
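  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a cross-check against `get_memory_footprint()`, PyTorch's allocator statistics report what is actually allocated and reserved on the GPU (a sketch assuming a single CUDA device at index 0; reserved memory includes the allocator's cache, so it is normally larger)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare against PyTorch's own allocator statistics (device 0 assumed)\n",
    "allocated_mib = torch.cuda.memory_allocated(0) / (1024 ** 2)\n",
    "reserved_mib = torch.cuda.memory_reserved(0) / (1024 ** 2)\n",
    "print(f'allocated: {allocated_mib:.2f}MiB, reserved: {reserved_mib:.2f}MiB')"
   ]
  },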
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sat Feb 17 18:12:36 2024       \n",
      "+---------------------------------------------------------------------------------------+\n",
      "| NVIDIA-SMI 535.146.02             Driver Version: 535.146.02   CUDA Version: 12.2     |\n",
      "|-----------------------------------------+----------------------+----------------------+\n",
      "| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
      "| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |\n",
      "|                                         |                      |               MIG M. |\n",
      "|=========================================+======================+======================|\n",
      "|   0  NVIDIA GeForce RTX 3090        On  | 00000000:C1:00.0 Off |                  N/A |\n",
      "| 32%   43C    P2             134W / 350W |   4408MiB / 24576MiB |      0%      Default |\n",
      "|                                         |                      |                  N/A |\n",
      "+-----------------------------------------+----------------------+----------------------+\n",
      "                                                                                         \n",
      "+---------------------------------------------------------------------------------------+\n",
      "| Processes:                                                                            |\n",
      "|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |\n",
      "|        ID   ID                                                             Usage      |\n",
      "|=======================================================================================|\n",
      "+---------------------------------------------------------------------------------------+\n"
     ]
    }
   ],
   "source": [
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 3,899,392 || all params: 6,247,483,392 || trainable%: 0.06241540401681151\n"
     ]
    }
   ],
   "source": [
    "from peft import TaskType, LoraConfig, get_peft_model, prepare_model_for_kbit_training\n",
    "from peft.utils import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING\n",
    "\n",
    "kbit_model = prepare_model_for_kbit_training(model)\n",
    "target_modules = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING['chatglm']\n",
    "\n",
    "lora_config = LoraConfig(\n",
    "    target_modules=target_modules,\n",
    "    r=lora_rank,\n",
    "    lora_alpha=lora_alpha,\n",
    "    lora_dropout=lora_dropout,\n",
    "    bias='none',\n",
    "    inference_mode=False,\n",
    "    task_type=TaskType.CAUSAL_LM\n",
    ")\n",
    "\n",
    "qlora_model = get_peft_model(kbit_model, lora_config)\n",
    "qlora_model.print_trainable_parameters()"
   ]
  },
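  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An optional check of where the adapters were injected: the loop below (a sketch using the standard `nn.Module.named_modules()` API; exact module names depend on the peft version) counts how many layers received a LoRA `A` matrix for the `target_modules` listed above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count the modules that were wrapped with a LoRA A matrix\n",
    "lora_module_names = sorted({name.split('.lora_A')[0]\n",
    "                            for name, _ in qlora_model.named_modules()\n",
    "                            if 'lora_A' in name})\n",
    "print('target_modules:', target_modules)\n",
    "print('modules wrapped with LoRA:', len(lora_module_names))"
   ]
  },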
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5.量化后的模型进行微调"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datetime\n",
    "\n",
    "timestamp = datetime.datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
    "\n",
    "train_epochs = 3\n",
    "output_dir = f\"models/{model_name_or_path}-epoch{train_epochs}-{timestamp}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments, Trainer\n",
    "\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=output_dir,                            # 输出目录\n",
    "    per_device_train_batch_size=4,                     # 每个设备的训练批量大小\n",
    "    gradient_accumulation_steps=1,                     # 梯度累积步数\n",
    "    learning_rate=1e-3,                                # 学习率\n",
    "    num_train_epochs=train_epochs,                     # 训练轮数\n",
    "    lr_scheduler_type=\"linear\",                        # 学习率调度器类型\n",
    "    warmup_ratio=0.1,                                  # 预热比例\n",
    "    logging_steps=1,                                 # 日志记录步数\n",
    "    save_strategy=\"steps\",                             # 模型保存策略\n",
    "    save_steps=10,                                    # 模型保存步数\n",
    "    optim=\"adamw_torch\",                               # 优化器类型\n",
    "    fp16=True,                                        # 是否使用混合精度训练\n",
    ")"
   ]
  },
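  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A back-of-the-envelope step count (a sketch assuming a single GPU): with `per_device_train_batch_size=4` and no gradient accumulation, each epoch takes `ceil(len(tokenized_dataset) / 4)` optimizer steps, so the trainer below should report three times that as the total number of steps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "# Expected number of optimizer steps for the run configured above (single GPU assumed)\n",
    "effective_batch_size = (training_args.per_device_train_batch_size\n",
    "                        * training_args.gradient_accumulation_steps)\n",
    "steps_per_epoch = math.ceil(len(tokenized_dataset) / effective_batch_size)\n",
    "print(f'steps per epoch: {steps_per_epoch}, total steps: {steps_per_epoch * train_epochs}')"
   ]
  },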
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = Trainer(\n",
    "        model=qlora_model,\n",
    "        args=training_args,\n",
    "        train_dataset=tokenized_dataset,\n",
    "        data_collator=data_collator\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='120' max='120' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [120/120 01:51, Epoch 3/3]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>4.665100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>4.679700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>4.691400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>4.578200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>4.222900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>3.933600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>3.975000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>3.800300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>3.380300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>3.291800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>11</td>\n",
       "      <td>3.120800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>12</td>\n",
       "      <td>3.087300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>13</td>\n",
       "      <td>3.043300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>14</td>\n",
       "      <td>2.613800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>15</td>\n",
       "      <td>2.244600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>16</td>\n",
       "      <td>2.005500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>17</td>\n",
       "      <td>1.748500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>18</td>\n",
       "      <td>2.146100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>19</td>\n",
       "      <td>1.262100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>20</td>\n",
       "      <td>1.591300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>21</td>\n",
       "      <td>1.469300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>22</td>\n",
       "      <td>1.001000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>23</td>\n",
       "      <td>0.759700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>24</td>\n",
       "      <td>0.592200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>25</td>\n",
       "      <td>0.431500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>26</td>\n",
       "      <td>0.299800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>27</td>\n",
       "      <td>0.530600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>28</td>\n",
       "      <td>0.425600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>29</td>\n",
       "      <td>0.157200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>30</td>\n",
       "      <td>0.245800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>31</td>\n",
       "      <td>0.136800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>32</td>\n",
       "      <td>0.131100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>33</td>\n",
       "      <td>0.042500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>34</td>\n",
       "      <td>0.090800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>35</td>\n",
       "      <td>0.039200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>36</td>\n",
       "      <td>0.059400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>37</td>\n",
       "      <td>0.056600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>38</td>\n",
       "      <td>0.042200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>39</td>\n",
       "      <td>0.030000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>40</td>\n",
       "      <td>0.025300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>41</td>\n",
       "      <td>0.021400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>42</td>\n",
       "      <td>0.019000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>43</td>\n",
       "      <td>0.018800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>44</td>\n",
       "      <td>0.019200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>45</td>\n",
       "      <td>0.018100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>46</td>\n",
       "      <td>0.012300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>47</td>\n",
       "      <td>0.011000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>48</td>\n",
       "      <td>0.015500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>49</td>\n",
       "      <td>0.011700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>50</td>\n",
       "      <td>0.008400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>51</td>\n",
       "      <td>0.005500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>52</td>\n",
       "      <td>0.005700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>53</td>\n",
       "      <td>0.005600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>54</td>\n",
       "      <td>0.005600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>55</td>\n",
       "      <td>0.005400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>56</td>\n",
       "      <td>0.004800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>57</td>\n",
       "      <td>0.003900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>58</td>\n",
       "      <td>0.003900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>59</td>\n",
       "      <td>0.004200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>60</td>\n",
       "      <td>0.003800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>61</td>\n",
       "      <td>0.003500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>62</td>\n",
       "      <td>0.003300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>63</td>\n",
       "      <td>0.002700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>64</td>\n",
       "      <td>0.002300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>65</td>\n",
       "      <td>0.003800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>66</td>\n",
       "      <td>0.002400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>67</td>\n",
       "      <td>0.002900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>68</td>\n",
       "      <td>0.001900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>69</td>\n",
       "      <td>0.001800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>70</td>\n",
       "      <td>0.002800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>71</td>\n",
       "      <td>0.004100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>72</td>\n",
       "      <td>0.002200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>73</td>\n",
       "      <td>0.002100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>74</td>\n",
       "      <td>0.002300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>75</td>\n",
       "      <td>0.002200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>76</td>\n",
       "      <td>0.001800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>77</td>\n",
       "      <td>0.001900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>78</td>\n",
       "      <td>0.001700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>79</td>\n",
       "      <td>0.001800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>80</td>\n",
       "      <td>0.001500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>81</td>\n",
       "      <td>0.001400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>82</td>\n",
       "      <td>0.001300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>83</td>\n",
       "      <td>0.001400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>84</td>\n",
       "      <td>0.001600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>85</td>\n",
       "      <td>0.001600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>86</td>\n",
       "      <td>0.001400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>87</td>\n",
       "      <td>0.001400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>88</td>\n",
       "      <td>0.001300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>89</td>\n",
       "      <td>0.001400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>90</td>\n",
       "      <td>0.001400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>91</td>\n",
       "      <td>0.001400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>92</td>\n",
       "      <td>0.001100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>93</td>\n",
       "      <td>0.001500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>94</td>\n",
       "      <td>0.001400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>95</td>\n",
       "      <td>0.001500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>96</td>\n",
       "      <td>0.001200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>97</td>\n",
       "      <td>0.001300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>98</td>\n",
       "      <td>0.001400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>99</td>\n",
       "      <td>0.001400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>100</td>\n",
       "      <td>0.001300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>101</td>\n",
       "      <td>0.001200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>102</td>\n",
       "      <td>0.001400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>103</td>\n",
       "      <td>0.001300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>104</td>\n",
       "      <td>0.001100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>105</td>\n",
       "      <td>0.001300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>106</td>\n",
       "      <td>0.001200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>107</td>\n",
       "      <td>0.001100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>108</td>\n",
       "      <td>0.001400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>109</td>\n",
       "      <td>0.001200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>110</td>\n",
       "      <td>0.001100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>111</td>\n",
       "      <td>0.001200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>112</td>\n",
       "      <td>0.001300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>113</td>\n",
       "      <td>0.001100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>114</td>\n",
       "      <td>0.001000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>115</td>\n",
       "      <td>0.001300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>116</td>\n",
       "      <td>0.001000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>117</td>\n",
       "      <td>0.001100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>118</td>\n",
       "      <td>0.001400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>119</td>\n",
       "      <td>0.001100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>120</td>\n",
       "      <td>0.001200</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/lib/python3.10/site-packages/peft/utils/save_and_load.py:148: UserWarning: Could not find a config file in /root/autodl-tmp/models/chatglm3-6b - will assume that the vocabulary was not modified.\n",
      "  warnings.warn(\n",
      "/root/miniconda3/lib/python3.10/site-packages/peft/utils/save_and_load.py:148: UserWarning: Could not find a config file in /root/autodl-tmp/models/chatglm3-6b - will assume that the vocabulary was not modified.\n",
      "  warnings.warn(\n",
      "/root/miniconda3/lib/python3.10/site-packages/peft/utils/save_and_load.py:148: UserWarning: Could not find a config file in /root/autodl-tmp/models/chatglm3-6b - will assume that the vocabulary was not modified.\n",
      "  warnings.warn(\n",
      "/root/miniconda3/lib/python3.10/site-packages/peft/utils/save_and_load.py:148: UserWarning: Could not find a config file in /root/autodl-tmp/models/chatglm3-6b - will assume that the vocabulary was not modified.\n",
      "  warnings.warn(\n",
      "/root/miniconda3/lib/python3.10/site-packages/peft/utils/save_and_load.py:148: UserWarning: Could not find a config file in /root/autodl-tmp/models/chatglm3-6b - will assume that the vocabulary was not modified.\n",
      "  warnings.warn(\n",
      "/root/miniconda3/lib/python3.10/site-packages/peft/utils/save_and_load.py:148: UserWarning: Could not find a config file in /root/autodl-tmp/models/chatglm3-6b - will assume that the vocabulary was not modified.\n",
      "  warnings.warn(\n",
      "/root/miniconda3/lib/python3.10/site-packages/peft/utils/save_and_load.py:148: UserWarning: Could not find a config file in /root/autodl-tmp/models/chatglm3-6b - will assume that the vocabulary was not modified.\n",
      "  warnings.warn(\n",
      "/root/miniconda3/lib/python3.10/site-packages/peft/utils/save_and_load.py:148: UserWarning: Could not find a config file in /root/autodl-tmp/models/chatglm3-6b - will assume that the vocabulary was not modified.\n",
      "  warnings.warn(\n",
      "/root/miniconda3/lib/python3.10/site-packages/peft/utils/save_and_load.py:148: UserWarning: Could not find a config file in /root/autodl-tmp/models/chatglm3-6b - will assume that the vocabulary was not modified.\n",
      "  warnings.warn(\n",
      "/root/miniconda3/lib/python3.10/site-packages/peft/utils/save_and_load.py:148: UserWarning: Could not find a config file in /root/autodl-tmp/models/chatglm3-6b - will assume that the vocabulary was not modified.\n",
      "  warnings.warn(\n",
      "/root/miniconda3/lib/python3.10/site-packages/peft/utils/save_and_load.py:148: UserWarning: Could not find a config file in /root/autodl-tmp/models/chatglm3-6b - will assume that the vocabulary was not modified.\n",
      "  warnings.warn(\n",
      "/root/miniconda3/lib/python3.10/site-packages/peft/utils/save_and_load.py:148: UserWarning: Could not find a config file in /root/autodl-tmp/models/chatglm3-6b - will assume that the vocabulary was not modified.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=120, training_loss=0.5912682440277421, metrics={'train_runtime': 113.0808, 'train_samples_per_second': 4.245, 'train_steps_per_second': 1.061, 'total_flos': 4457292867059712.0, 'train_loss': 0.5912682440277421, 'epoch': 3.0})"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/lib/python3.10/site-packages/peft/utils/save_and_load.py:148: UserWarning: Could not find a config file in /root/autodl-tmp/models/chatglm3-6b - will assume that the vocabulary was not modified.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "trainer.model.save_pretrained(output_dir)"
   ]
  },
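  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optionally, the tokenizer can be stored next to the LoRA adapter so that `output_dir` is self-contained (a sketch using the standard `save_pretrained` API; the inference section below still loads the tokenizer from the original checkpoint, so this step is not required)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: keep a copy of the tokenizer alongside the adapter weights\n",
    "tokenizer.save_pretrained(output_dir)"
   ]
  },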
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. 模型推理对比"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. 加载量化后未微调的模型 （需要重启环境）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5bcc99b782004d618f7735dd0bc82eec",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "ChatGLMForConditionalGeneration(\n",
       "  (transformer): ChatGLMModel(\n",
       "    (embedding): Embedding(\n",
       "      (word_embeddings): Embedding(65024, 4096)\n",
       "    )\n",
       "    (rotary_pos_emb): RotaryEmbedding()\n",
       "    (encoder): GLMTransformer(\n",
       "      (layers): ModuleList(\n",
       "        (0-27): 28 x GLMBlock(\n",
       "          (input_layernorm): RMSNorm()\n",
       "          (self_attention): SelfAttention(\n",
       "            (query_key_value): Linear4bit(in_features=4096, out_features=4608, bias=True)\n",
       "            (core_attention): CoreAttention(\n",
       "              (attention_dropout): Dropout(p=0.0, inplace=False)\n",
       "            )\n",
       "            (dense): Linear4bit(in_features=4096, out_features=4096, bias=False)\n",
       "          )\n",
       "          (post_attention_layernorm): RMSNorm()\n",
       "          (mlp): MLP(\n",
       "            (dense_h_to_4h): Linear4bit(in_features=4096, out_features=27392, bias=False)\n",
       "            (dense_4h_to_h): Linear4bit(in_features=13696, out_features=4096, bias=False)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "      (final_layernorm): RMSNorm()\n",
       "    )\n",
       "    (output_layer): Linear(in_features=4096, out_features=65024, bias=False)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig\n",
    "\n",
    "model_name_or_path = '/root/autodl-tmp/models/chatglm3-6b'  # 模型ID或本地路径\n",
    "_compute_dtype_map = {\n",
    "    'fp32': torch.float32,\n",
    "    'fp16': torch.float16,\n",
    "    'bf16': torch.bfloat16\n",
    "}\n",
    "\n",
    "# QLoRA 量化配置\n",
    "q_config = BitsAndBytesConfig(load_in_4bit=True,\n",
    "                              bnb_4bit_quant_type='nf4',\n",
    "                              bnb_4bit_use_double_quant=True,\n",
    "                              bnb_4bit_compute_dtype=_compute_dtype_map['bf16'])\n",
    "# 加载量化后模型\n",
    "base_model = AutoModel.from_pretrained(model_name_or_path,\n",
    "                                  quantization_config=q_config,\n",
    "                                  device_map='auto',\n",
    "                                  trust_remote_code=True)\n",
    "base_model.requires_grad_(False)\n",
    "base_model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. 加载微调后的模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "peft_model_path = f\"models/{model_name_or_path}-epoch3-20240217_181236\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from peft import PeftModel, PeftConfig\n",
    "config = PeftConfig.from_pretrained(peft_model_path)\n",
    "model = PeftModel.from_pretrained(base_model, peft_model_path)"
   ]
  },
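  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Mirroring the setup of the base model, the adapter-wrapped model is frozen and switched to eval mode before generation (a sketch; `generate` also works without this, but it avoids accidental gradient tracking)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Freeze parameters and switch to inference mode, as done for the base model\n",
    "model.requires_grad_(False)\n",
    "model.eval();  # trailing semicolon suppresses the long module repr"
   ]
  },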
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import display, HTML\n",
    "def compare_chatglm_results(query):\n",
    "    base_response, base_history = base_model.chat(tokenizer, query)\n",
    "\n",
    "    inputs = tokenizer(query, return_tensors=\"pt\").to(0)\n",
    "    ft_out = model.generate(**inputs, max_new_tokens=512)\n",
    "    ft_response = tokenizer.decode(ft_out[0], skip_special_tokens=True)\n",
    "    \n",
    "    #print(f\"问题：{query}\\n\\n原始输出：\\n{base_response}\\n\\n\\nChatGLM3-6B微调后：\\n{ft_response}\")\n",
    "    \n",
    "    html_text = \"<p style='color: red; font-weight: bold;'>问题：{}</p><br/><br/><p style='color: green; font-weight: bold;'>原始输出：<br/>{}<br/><br/><br/></p><p style='color: gray; font-weight: bold;'>ChatGLM3-6B微调后：<br/>{}</p>\".format(query,base_response,ft_response)\n",
    "    display(HTML(html_text))\n",
    "    return base_response, ft_response"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<p style='color: red; font-weight: bold;'>问题：解释下乾卦是什么？</p><br/><br/><p style='color: green; font-weight: bold;'>原始输出：<br/>在周易中，乾卦是六十四卦之首，由六个阳爻组成，象征着天。它所代表的是刚健、健行、刚健不屈的意境。乾卦的核心哲学是：天道刚健，运行不已，君子观此卦象，从而以天为法，自强不息。\n",
       "\n",
       "乾卦象征天，为大通而至正。得此卦者，名利双收，应把握机会，争取成果。然而，切勿过于骄傲自满，而应保持谦逊、冷静和警惕。在事业、经商、求名等方面，乾卦皆暗示着大吉大利，但也警示着必须坚持正道、修养德行，方能永远亨通。\n",
       "\n",
       "在婚恋方面，乾卦提示着阳盛阴衰，但也强调刚柔相济，相互补足，形成美满的结果。在决策方面，则是强调刚健、正直、公允，自强不息的实质，需要修养德行、坚定信念，方能克服困难，消除灾难。<br/><br/><br/></p><p style='color: gray; font-weight: bold;'>ChatGLM3-6B微调后：<br/>[gMASK]sop 解释下乾卦是什么？ 在周易中，乾卦是六十四卦之首，由六个阳爻组成，象征着天。它所代表的是刚健、健行、刚健不屈的意境。乾卦的核心哲学是：天道刚健，运行不已，君子观此卦象，从而以天为法，自强不息。\n",
       "\n",
       "乾卦象征天，为大通而至正。得此卦者，名利双收，应把握机会，争取成果。然而，切勿过于骄傲自满，而应保持谦逊、冷静和警惕。在事业、经商、求名等方面，乾卦皆暗示着大吉大利，但也警示着必须坚持正道、修养德行，方能永远亨通。\n",
       "\n",
       "在婚恋方面，乾卦提示着阳盛阴衰，但也强调刚柔相济，相互补足，形成美满的结果。在决策方面，则是强调刚健、正直、公允，自强不息的实质，需要修养德行、坚定信念，方能克服困难，消除灾难。</p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "base_response, ft_response = compare_chatglm_results(query=\"解释下乾卦是什么？\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<p style='color: red; font-weight: bold;'>问题：地水师卦是什么？</p><br/><br/><p style='color: green; font-weight: bold;'>原始输出：<br/>在周易中，地水师卦是一个由三个卦相叠而成的卦象，代表着自然的秩序和规律。它象征着大地上的河流和水源，代表着润泽和滋养。这个卦象代表着一种安全、稳定和吉祥的状态，但也提醒着要小心谨慎，避免危险。\n",
       "\n",
       "地水师卦的核心哲学是：君子观此卦象，应当充满敬畏之心，严格遵循自然的规律，以善良和正直的行为去适应和处理问题。在这个卦象中，自然界的秩序和规律被强调，要求人们顺应自然，才能获得幸福和成功。\n",
       "\n",
       "地水师卦的卦象是由上卦坎（水）下卦坤（地）相叠而成。坎卦代表险池，下卦坤卦代表地基， both 卦象代表着危险和危险相反。然而，尽管卦象本身具有危险，但在其中隐藏着大吉大利的消息。\n",
       "\n",
       "地水师卦的运势：在事业、经商、求名、婚恋等方面的运势都有所提高。在经商方面，应当谨慎行事，遵守道德原则，远离的成功和幸福之兆。在求名方面，有望获得好的成绩和高度赞扬。在婚恋方面，建议仔细选择伴侣，务必慎重。\n",
       "\n",
       "地水师卦的建议：面对困难的选择时，应当谨慎思考，遵循善良和正直的原则，妥善处理问题。同时，要善于利用周围资源和机会，努力求得解决方案，务必坚定信念，勇于行动。<br/><br/><br/></p><p style='color: gray; font-weight: bold;'>ChatGLM3-6B微调后：<br/>[gMASK]sop 地水师卦是什么？ 在周易中，地水师卦是一个由坎卦（水）和坤卦（地）相叠而成的异卦。这一卦象代表着军队的力量和军情的总指挥，预示着吉祥无灾。象辞中描述了地水师卦的特点，以及君子应当如何行动。\n",
       "\n",
       "卦象中，上卦为坤卦，代表地，下卦为坎卦，代表水。坤卦代表柔顺、地气舒展之象，坎卦代表柔顺、地气舒展之象。两卦相互融合，形成mississippi（混然一体）之象，因此得名地水师卦。\n",
       "\n",
       "君子观此卦象，应当以柔顺、地气舒展之态度来处理事情，务必moderation（适度）行事。地水师卦象征着回旋的波浪，预示着顺从、耐心、等待，而非强制行动。在事业、经商、求名、婚恋等方面的决策中，都需遵循mississippi之原则，务必耐心等待，顺从自然，方可成功。</p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "base_response, ft_response = compare_chatglm_results(query=\"地水师卦是什么？\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<p style='color: red; font-weight: bold;'>问题：周易中的讼卦是什么</p><br/><br/><p style='color: green; font-weight: bold;'>原始输出：<br/>在周易中，讼卦是一个充满警示的卦象。它由上卦乾（天）和下卦坎（水）组成，代表着天与水背道而驰，形成争讼的局面。虽然事情开始时有利可图，但必须警惕戒惧，因为中间虽然吉利，但最终会带来凶险。对于涉及大川，涉水渡河的行动不利。因此，君子观此卦象，应当慎之又慎，杜绝争讼之事，并在谋事之初谨慎行事。讼卦的核心哲学是要避免争讼，退而让人，求得化解，安于正理，方可避免意外之灾。在事业上，务必避免介入诉讼纠纷的争执之中，与其这样，不如退而让人。即使最终获胜，也难免得失不均。经商方面，要坚持公正、公平、互利的原则，避免冲突，这样会有好结果。而对于求名、婚恋和决策，也都需要慎重行事，避免盲目追求，退让让人，可助事业、婚姻和决策的发展。<br/><br/><br/></p><p style='color: gray; font-weight: bold;'>ChatGLM3-6B微调后：<br/>[gMASK]sop 周易中的讼卦是什么 在周易中，讼卦是一个充满警示的卦象。它由上卦乾（天）和下卦坎（水）组成，代表着天与水背道而驰，形成争讼的局面。虽然事情开始时有利可图，但必须警惕戒惧，因为中间虽然吉利，但最终会带来凶险。对于涉及大川，涉水渡河的行动不利。因此，君子观此卦象，应当慎之又慎，杜绝争讼之事，并在谋事之初谨慎行事。讼卦的核心哲学是要避免争讼，退而让人，求得化解，安于正理，方可避免意外之灾。在事业上，务必避免介入诉讼纠纷的争执之中，与其这样，不如退而让人。即使最终获胜，也难免得失不均。经商方面，要坚持公正、公平、互利的原则，避免冲突，这样会有好结果。而对于求名、婚恋和决策，也都需要慎重行事，避免盲目追求，退让让人，可助事业、婚姻和决策的发展。</p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "base_response, ft_response = compare_chatglm_results(query=\"周易中的讼卦是什么\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
